refactor: Migrate to usize indexing
#4273
Conversation
laggui
left a comment
Changes in burn LGTM! Pending rev updates once linked PRs are merged.
  burn_tensor::bf16,
  "../autodiff/mod.rs",
- ["vulkan", "metal"] // ["cuda", "rocm"] TODO
+ ["metal"] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul
Whoops, I added bf16 for vulkan but clearly the tests don't pass 😅 thanks for fixing.
[Unrelated to this PR]
I vaguely remember adding it when refactoring the tests since it was a supported global type for vulkan in cubecl. Also reflected by B::supports_dtype:
burn/crates/burn-wgpu/src/lib.rs
Lines 117 to 145 in 8682709
#[test]
fn should_support_dtypes() {
    type B = Wgpu;
    let device = Default::default();
    assert!(B::supports_dtype(&device, DType::F32));
    assert!(B::supports_dtype(&device, DType::I64));
    assert!(B::supports_dtype(&device, DType::I32));
    assert!(B::supports_dtype(&device, DType::U64));
    assert!(B::supports_dtype(&device, DType::U32));
    assert!(B::supports_dtype(
        &device,
        DType::QFloat(CubeTensor::<WgpuRuntime>::default_scheme())
    ));
    // Registered as supported type but we don't actually use it?
    assert!(B::supports_dtype(&device, DType::Bool));
    #[cfg(feature = "vulkan")]
    {
        assert!(B::supports_dtype(&device, DType::F16));
        assert!(B::supports_dtype(&device, DType::BF16));
        assert!(B::supports_dtype(&device, DType::I16));
        assert!(B::supports_dtype(&device, DType::I8));
        assert!(B::supports_dtype(&device, DType::U16));
        assert!(B::supports_dtype(&device, DType::U8));
        assert!(!B::supports_dtype(&device, DType::F64));
        assert!(!B::supports_dtype(&device, DType::Flex32));
    }
}
maybe we should have a better way to represent the actual supported types?
burn/crates/burn-cubecl/src/backend.rs
Lines 83 to 88 in 8682709
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
    let client = R::client(device);
    let ty: StorageType = dtype.into();
    client.properties().supports_type(ty.elem_type())
}
That way, tested dtypes can actually reflect supported dtypes.
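A minimal sketch of that idea, assuming a hypothetical SUPPORTED_DTYPES list declared by the backend (burn has no such constant today):
// Hypothetical sketch: derive the test from one declared list so the
// asserts cannot drift from what the backend claims to support.
// `SUPPORTED_DTYPES` is an assumption, not an existing burn API.
const SUPPORTED_DTYPES: &[DType] =
    &[DType::F32, DType::I64, DType::I32, DType::U64, DType::U32];

#[test]
fn should_support_declared_dtypes() {
    type B = Wgpu;
    let device = Default::default();
    for &dtype in SUPPORTED_DTYPES {
        assert!(
            B::supports_dtype(&device, dtype),
            "expected support for {dtype:?}"
        );
    }
}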
The supports_type method on the CubeCL side only checks if a type is supported in any way (in this case, it's supported for conversion, as a type for buffers, and for dot product on Intel, along with tensor core instructions). It's kinda tough though because there's no good way to express that in just a single boolean (hence why the TypeUsage enum set exists in CubeCL).
This is how it's registered for Vulkan:
if let Some(bfloat16) = ext_feat.bfloat16 {
    if bfloat16.shader_b_float16_type == TRUE {
        register(
            ElemType::Float(FloatKind::BF16).into(),
            TypeUsage::Conversion | TypeUsage::Buffer,
        );
    }
    if bfloat16.shader_b_float16_dot_product == TRUE {
        register(
            ElemType::Float(FloatKind::BF16).into(),
            TypeUsage::DotProduct.into(),
        );
    }
}
So it's supported for matmul and casting, but none of the other ops.
> The supports_type method on the CubeCL side only checks if a type is supported in any way
Yeah and for the first draft I simply mirrored that for the backends, but I think it should be refined.
> It's kinda tough though because there's no good way to express that in just a single boolean (hence why the TypeUsage enum set exists in CubeCL).
That's a good point. It's still useful to query backend supported types for burn, so maybe we should also define an enum similar to TypeUsage (without atomics, and perhaps consolidating conversion / buffer into a "storage" variant or similar)?
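A rough sketch of what that could look like on the burn side (every name below is hypothetical; none of this exists in burn today):
// Hypothetical: a burn-side usage set similar to CubeCL's TypeUsage,
// without atomics, and with Conversion/Buffer folded into one Storage
// variant as suggested above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DTypeUsage {
    Storage,    // can live in buffers and be converted to/from other types
    Arithmetic, // usable in element-wise kernels
    Matmul,     // accelerated matmul (dot product / tensor core paths)
}

// bf16 on vulkan would then advertise Storage + Matmul, but not Arithmetic.
fn vulkan_bf16_usages() -> Vec<DTypeUsage> {
    vec![DTypeUsage::Storage, DTypeUsage::Matmul]
}

fn main() {
    let usages = vulkan_bf16_usages();
    assert!(usages.contains(&DTypeUsage::Matmul));
    assert!(!usages.contains(&DTypeUsage::Arithmetic));
}
A supports_dtype-style query could then take a usage argument and answer per capability instead of collapsing everything into one boolean.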
Regardless, that's something for a separate issue/PR I think.
Oh yes entirely, that's why I prefaced my initial comment with [Unrelated to this PR] 😄
Just wanted to get your thoughts since this was related to the bf16 change.
Codecov Report

@@            Coverage Diff             @@
##             main    #4273      +/-   ##
==========================================
+ Coverage   68.85%   68.87%   +0.02%
==========================================
  Files        1405     1405
  Lines      167686   167607      -79
==========================================
+ Hits       115456   115442      -14
- Misses      52230    52165      -65
==========================================

☔ View full report in Codecov by Sentry.
Pull Request Template
Checklist
The cargo run-checks command has been executed.
Related Issues/PRs
Migrates to changes in tracel-ai/cubecl#1127 and tracel-ai/cubek#51
Changes
Migrates cubecl kernels and fusion to usize indexing; a brief illustrative sketch of the change follows.
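For illustration only (this is not the actual kernel code, just the shape of the change): index arithmetic now uses usize, whose width follows the compilation target, instead of a fixed u32.
// Illustrative sketch, not cubecl's implementation. With u32 offsets,
// tensors past u32::MAX elements overflow; usize leaves the index width
// to the target, which is what makes optional 64-bit indexing possible.
fn strided_offset(linear: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut remaining = linear;
    let mut offset = 0usize;
    for (dim, stride) in shape.iter().zip(strides.iter()).rev() {
        offset += (remaining % dim) * stride;
        remaining /= dim;
    }
    offset
}

fn main() {
    // 2x3 row-major tensor: linear index 4 -> coords (1, 1) -> offset 4.
    assert_eq!(strided_offset(4, &[2, 3], &[3, 1]), 4);
}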
Testing
The test suite runs successfully, though 64-bit indexing is not yet enabled.